import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

import wandb
from wandb.util import get_module

if TYPE_CHECKING:
    np_array = get_module("numpy.array")
    torch_float_tensor = get_module("torch.FloatTensor")


def chunkify(input_list, chunk_size) -> List:
    """Split `input_list` into consecutive slices of at most `chunk_size` items.

    Arguments:
        input_list: Any sliceable sequence supporting `len()`.
        chunk_size: Desired number of items per chunk. Values below 1 are
            treated as 1.

    Returns:
        A list of slices of `input_list`, in order. The last slice may be
        shorter than the others; an empty input yields an empty list.
    """
    # Guard against zero/negative sizes, which `range` cannot use as a step.
    step = chunk_size if chunk_size > 1 else 1
    chunks = []
    for start in range(0, len(input_list), step):
        chunks.append(input_list[start : start + step])
    return chunks


def _serialize_generator(generator: Any) -> Any:
    """Turn a generator, or a list/tuple of generators, into plain loggable data."""
    if generator is None:
        return None
    if isinstance(generator, (list, tuple)):
        # One generator per sample may be supplied; serialize each entry.
        return [_serialize_generator(item) for item in generator]
    return {
        "seed": generator.initial_seed(),
        "device": generator.device,
        "random_state": generator.get_state().cpu().numpy().tolist(),
    }


def get_updated_kwargs(
    pipeline: Any, args: Sequence[Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Merge positional args and signature defaults into a pipeline call's kwargs.

    The signature of `pipeline.__call__` is inspected so that every parameter
    ends up as a key in `kwargs`: positional `args` are mapped to parameter
    names by position, and parameters that were not supplied are filled with
    their declared default (`inspect.Parameter.empty` when the parameter has
    no default).

    Arguments:
        pipeline: Object whose `__call__` signature is inspected.
        args: Positional arguments that were passed to the pipeline call.
            Positional arguments beyond the number of parameters in the
            signature are ignored.
        kwargs: Keyword arguments that were passed to the pipeline call.
            This dictionary is updated in place.

    Returns:
        The same `kwargs` dictionary. If it contains a `"generator"` entry,
        that entry is replaced by a dictionary with the keys `"seed"`,
        `"device"` and `"random_state"` (or a list of such dictionaries when a
        list/tuple of generators was given, or `None` when it was `None`).

    Side effects:
        When `kwargs["ip_adapter_image"]` is present and not `None`, it is
        logged to W&B as an image under the key `"IP-Adapter-Image"`.
    """
    call_parameters = inspect.signature(pipeline.__call__).parameters

    # `zip` stops at the shorter input, so surplus positional arguments do not
    # raise an IndexError.
    for name, arg in zip(call_parameters, args):
        kwargs[name] = arg

    # Fill in everything the caller did not pass explicitly.
    for name, parameter in call_parameters.items():
        kwargs.setdefault(name, parameter.default)

    if "generator" in kwargs:
        kwargs["generator"] = _serialize_generator(kwargs["generator"])

    if kwargs.get("ip_adapter_image") is not None:
        wandb.log({"IP-Adapter-Image": wandb.Image(kwargs["ip_adapter_image"])})
    return kwargs


def postprocess_pils_to_np(image: List) -> "np_array":
    """Convert a list of images into one stacked `uint8` NumPy array.

    Each element is converted with `np.array`, cast to `uint8`, and has its
    axes reordered from `(0, 1, 2)` to `(2, 0, 1)`; the per-image arrays are
    then stacked along a new leading axis.

    Arguments:
        image: List of images (PIL images, judging by the function name).
            Each one must convert to a 3-dimensional array for the axis
            reordering to succeed.

    Returns:
        A NumPy array holding all converted images, stacked on axis 0.
    """
    np = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )
    reordered_arrays = []
    for single_image in image:
        pixels = np.array(single_image).astype("uint8")
        # Move the last axis to the front.
        reordered_arrays.append(np.transpose(pixels, (2, 0, 1)))
    return np.stack(reordered_arrays, axis=0)


def postprocess_np_arrays_for_video(
    images: List["np_array"], normalize: Optional[bool] = False
) -> "np_array":
    """Stack a list of arrays and reorder the result's axes to `(0, 3, 1, 2)`.

    Arguments:
        images: List of NumPy arrays to stack along a new leading axis. The
            stacked result must be 4-dimensional for the axis reordering to
            succeed.
        normalize: When truthy, each array is multiplied by 255 and cast to
            `uint8` before stacking; otherwise the arrays are used as given.

    Returns:
        The stacked array with its last axis moved to position 1.
    """
    np = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )
    if normalize:
        images = [(frame * 255).astype("uint8") for frame in images]
    stacked_frames = np.stack(images, axis=0)
    return np.transpose(stacked_frames, (0, 3, 1, 2))


def decode_sdxl_t2i_latents(pipeline: Any, latents: "torch_float_tensor") -> List:
    """Decode latents generated by [`diffusers.StableDiffusionXLPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#stable-diffusion-xl).

    The latents are divided by the VAE's configured scaling factor, decoded by
    the pipeline's VAE, optionally watermarked, and post-processed to PIL
    output. All of this runs with gradient tracking disabled.

    Arguments:
        pipeline: (diffusers.DiffusionPipeline) The Diffusion Pipeline from
            [`diffusers`](https://huggingface.co/docs/diffusers).
        latents (torch.FloatTensor): The generated latents.

    Returns:
        List of `PIL` images corresponding to the generated latents.
    """
    torch = get_module(
        "torch",
        required="Please ensure PyTorch is installed. You can check out https://pytorch.org/get-started/locally/#start-locally for installation instructions.",
    )
    with torch.no_grad():
        # Upcasting only applies to a float16 VAE whose config requests it.
        vae_is_fp16 = pipeline.vae.dtype == torch.float16
        upcast_required = vae_is_fp16 and pipeline.vae.config.force_upcast

        if upcast_required:
            pipeline.upcast_vae()
            # Match the latents to the dtype the VAE's parameters now use.
            first_parameter = next(iter(pipeline.vae.post_quant_conv.parameters()))
            latents = latents.to(first_parameter.dtype)

        scaled_latents = latents / pipeline.vae.config.scaling_factor
        decoded = pipeline.vae.decode(scaled_latents, return_dict=False)[0]

        if upcast_required:
            # Put the VAE back into float16 once decoding is done.
            pipeline.vae.to(dtype=torch.float16)

        if pipeline.watermark is not None:
            decoded = pipeline.watermark.apply_watermark(decoded)

        pil_images = pipeline.image_processor.postprocess(decoded, output_type="pil")
        pipeline.maybe_free_model_hooks()
    return pil_images
